{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ONNX Script 模式重写进阶\n",
    "\n",
    "参考：[pattern_rewriting](https://github.com/microsoft/onnxscript/blob/main/examples/pattern_rewriting.py)\n",
    "\n",
    "本节展示了如何基于模式定义重写规则。\n",
    "\n",
    "```{topic} 主题\n",
    "目标是将 ONNX 模型中的一些节点替换为另一系列更高效的节点。\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import onnx\n",
    "import onnx.helper as oh\n",
    "import onnx.numpy_helper as onh\n",
    "\n",
    "import onnxscript\n",
    "from onnxscript import ir\n",
    "from onnxscript.rewriter import generic_pattern"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "定义简单模型："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def get_rotary_model(bad_model=False):\n",
    "    inputs = [\n",
    "        oh.make_tensor_value_info(\"x\", onnx.TensorProto.INT64, shape=[]),\n",
    "        oh.make_tensor_value_info(\"pos_ids\", onnx.TensorProto.FLOAT, shape=[]),\n",
    "        oh.make_tensor_value_info(\"axis\", onnx.TensorProto.INT64, shape=[]),\n",
    "    ]\n",
    "    nodes = [\n",
    "        oh.make_node(\"Unsqueeze\", [\"x\", \"axis\"], [\"_onx_unsqueeze0\"]),\n",
    "        oh.make_node(\"Cast\", [\"_onx_unsqueeze0\"], [\"_onx_cast0\"], to=1),\n",
    "        oh.make_node(\"MatMul\", [\"pos_ids\", \"_onx_cast0\"], [\"_onx_matmul0\"]),\n",
    "        oh.make_node(\"Transpose\", [\"_onx_matmul0\"], [\"_onx_transpose0\"]),\n",
    "        oh.make_node(\n",
    "            \"ConcatTrainingBad\" if bad_model else \"ConcatTraining\",\n",
    "            [\"_onx_transpose0\", \"_onx_transpose0\"],\n",
    "            [\"_onx_concattraining0\", \"_onx_concattraining1\"],\n",
    "            domain=\"com.microsoft\",\n",
    "        ),\n",
    "        oh.make_node(\"Sin\", [\"_onx_concattraining0\"], [\"_onx_sin0\"]),\n",
    "        oh.make_node(\"Cast\", [\"_onx_sin0\"], [\"_onx_cast02\"], to=1),\n",
    "        oh.make_node(\"Cos\", [\"_onx_concattraining0\"], [\"_onx_cos0\"]),\n",
    "        oh.make_node(\"Cast\", [\"_onx_cos0\"], [\"_onx_cast03\"], to=1),\n",
    "    ]\n",
    "    outputs = [\n",
    "        oh.make_tensor_value_info(\"_onx_cast02\", onnx.TensorProto.UNDEFINED, []),\n",
    "        oh.make_tensor_value_info(\"_onx_cast03\", onnx.TensorProto.UNDEFINED, []),\n",
    "    ]\n",
    "    model = oh.make_model(\n",
    "        oh.make_graph(\n",
    "            nodes,\n",
    "            \"experiment\",\n",
    "            inputs,\n",
    "            outputs,\n",
    "        ),\n",
    "        opset_imports=[\n",
    "            oh.make_opsetid(\"\", 18),\n",
    "            oh.make_opsetid(\"com.microsoft\", 18),\n",
    "        ],\n",
    "    )\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = get_rotary_model()\n",
    "ir_model = ir.serde.deserialize_model(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 重写模式(ONNX)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "op = onnxscript.opset18\n",
    "msft_op = onnxscript.values.Opset(\"com.microsoft\", 1)\n",
    "\n",
    "\n",
    "def rotary_match_pattern(x, pos_ids, axis):\n",
    "    \"\"\"The pattern to match.\"\"\"\n",
    "    unsqueeze = op.Unsqueeze(x, axis)\n",
    "    cast = op.Cast(unsqueeze, to=onnx.TensorProto.FLOAT)\n",
    "\n",
    "    matmul = op.MatMul(pos_ids, cast)\n",
    "    transpose = op.Transpose(matmul)\n",
    "    output, length = msft_op.ConcatTraining(transpose, transpose)\n",
    "\n",
    "    sin = op.Sin(output)\n",
    "    cast1 = op.Cast(sin, to=onnx.TensorProto.FLOAT)\n",
    "    cos = op.Cos(output)\n",
    "    cast2 = op.Cast(cos, to=onnx.TensorProto.FLOAT)\n",
    "    return cast1, cast2\n",
    "\n",
    "\n",
    "def validate_rotary_mapping(g, match_result) -> bool:\n",
    "    \"\"\"The validation post matching.\n",
    "\n",
    "    Returns True to validate the replacement,\n",
    "    False not to apply it.\n",
    "\n",
    "    :param g: model\n",
    "    :param match_result: matched nodes\n",
    "    \"\"\"\n",
    "    del g\n",
    "    del match_result\n",
    "    return True\n",
    "\n",
    "\n",
    "def rotary_apply_pattern(x, pos_ids, axis):\n",
    "    \"\"\"The replacement pattern.\"\"\"\n",
    "    cos_cache = op.Constant(value=onh.from_array(np.random.rand(256, 256).astype(np.float16)))\n",
    "    sin_cache = op.Constant(value=onh.from_array(np.random.rand(256, 256).astype(np.float16)))\n",
    "    part1, part2 = msft_op.RotaryEmbedding(x, pos_ids, cos_cache, sin_cache)\n",
    "    return part1, part2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 构建规则"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "'RotaryEmbedding' is not a known op in 'com.microsoft'\n",
      "'ConcatTraining' is not a known op in 'com.microsoft'\n"
     ]
    }
   ],
   "source": [
    "rule_with_validation_function = generic_pattern.make_pattern_rule(\n",
    "    rotary_match_pattern,\n",
    "    rotary_apply_pattern,\n",
    "    validate_rotary_mapping,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`validate_rotary_mapping` 函数总是返回 `True`。在这种情况下，可以忽略这个参数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "'RotaryEmbedding' is not a known op in 'com.microsoft'\n",
      "rotary_apply_pattern: Already defined.\n",
      "'ConcatTraining' is not a known op in 'com.microsoft'\n",
      "rotary_match_pattern: Already defined.\n"
     ]
    }
   ],
   "source": [
    "rule = generic_pattern.make_pattern_rule(rotary_match_pattern, rotary_apply_pattern)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "应用规则："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rule.apply_to_model(ir_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "重写模型："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "rewritten_model = ir.serde.serialize_model(ir_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "查看运行情况："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Constant() -> val_0\n",
      "Constant() -> val_1\n",
      "RotaryEmbedding(x, pos_ids, cos_cache, sin_cache) -> val_2, val_3\n"
     ]
    }
   ],
   "source": [
    "for node in rewritten_model.graph.node:\n",
    "    print(f\"{node.op_type}({', '.join(node.input)}) -> {', '.join(node.output)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "如果它失败了呢？"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['Unsqueeze', 'Cast', 'MatMul', 'Transpose', 'ConcatTrainingBad', 'Sin', 'Cast', 'Cos', 'Cast']\n"
     ]
    }
   ],
   "source": [
    "model = get_rotary_model(True)\n",
    "ir_model = ir.serde.deserialize_model(model)\n",
    "\n",
    "rule.apply_to_model(ir_model)\n",
    "rewritten_model = ir.serde.serialize_model(ir_model)\n",
    "\n",
    "print([n.op_type for n in rewritten_model.graph.node])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "匹配没有发生，我们可以增加细节。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "'RotaryEmbedding' is not a known op in 'com.microsoft'\n",
      "rotary_apply_pattern: Already defined.\n",
      "'ConcatTraining' is not a known op in 'com.microsoft'\n",
      "rotary_match_pattern: Already defined.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[GenericPattern.match] starts with %\"_onx_cast0\"<?,?> ⬅️ ::Cast(%\"_onx_unsqueeze0\") {to=1}\n",
      "[GenericPattern.match] match pattern <onnxscript.rewriter.generic_pattern.FunctionPattern object at 0x7fbe13cc1340>\n",
      "[GenericPattern.match] iteration=1 n_matched=1, n_stack=1, matched_types=Counter({'Cast': 1})\n",
      "[FunctionPattern.match] NONE - line: 460:onnxscript.rewriter.generic_pattern, op_type=Cast\n",
      "    --hint--: BACKWARD: different node types\n",
      "      %\"cos\"<?,?> ⬅️ ::Cos(%\"output\")\n",
      "      %\"_onx_unsqueeze0\"<?,?> ⬅️ ::Unsqueeze(%\"x\", %\"axis\")\n",
      "    iteration=0\n",
      "    --matched-- #1\n",
      "      Cast(%\"cos\"<?,?>) ~ Cast(%\"_onx_unsqueeze0\"<?,?>) [140454351111376-140454352338048]\n",
      "    len(stack)=0:[]\n",
      "[GenericPattern.match] done. backward failed.\n",
      "[GenericPattern.match] starts with %\"_onx_cast02\"<UNDEFINED,[]> ⬅️ ::Cast(%\"_onx_sin0\") {to=1}\n",
      "[GenericPattern.match] match pattern <onnxscript.rewriter.generic_pattern.FunctionPattern object at 0x7fbe13cc1340>\n",
      "[GenericPattern.match] iteration=1 n_matched=1, n_stack=1, matched_types=Counter({'Cast': 1})\n",
      "[FunctionPattern.match] NONE - line: 460:onnxscript.rewriter.generic_pattern, op_type=Cast\n",
      "    --hint--: BACKWARD: different node types\n",
      "      %\"cos\"<?,?> ⬅️ ::Cos(%\"output\")\n",
      "      %\"_onx_sin0\"<?,?> ⬅️ ::Sin(%\"_onx_concattraining0\")\n",
      "    iteration=0\n",
      "    --matched-- #1\n",
      "      Cast(%\"cos\"<?,?>) ~ Cast(%\"_onx_sin0\"<?,?>) [140454351111376-140454352338768]\n",
      "    len(stack)=0:[]\n",
      "[GenericPattern.match] done. backward failed.\n",
      "[GenericPattern.match] starts with %\"_onx_cast03\"<UNDEFINED,[]> ⬅️ ::Cast(%\"_onx_cos0\") {to=1}\n",
      "[GenericPattern.match] match pattern <onnxscript.rewriter.generic_pattern.FunctionPattern object at 0x7fbe13cc1340>\n",
      "[GenericPattern.match] iteration=1 n_matched=1, n_stack=1, matched_types=Counter({'Cast': 1})\n",
      "[GenericPattern._match_backward] match Cos((Value('_onx_concattraining0', type=None, shape=None, producer=anonymous_node:140454352338480, index=0),)) with Cos((Value('output', type=None, shape=None, producer=n4, index=0),)) (pattern)\n",
      "[GenericPattern._match_backward] add 1 nodes\n",
      "[GenericPattern.match] iteration=2 n_matched=2, n_stack=1, matched_types=Counter({'Cast': 1, 'Cos': 1})\n",
      "[FunctionPattern.match] NONE - line: 460:onnxscript.rewriter.generic_pattern, op_type=Cast\n",
      "    --hint--: BACKWARD: different node types\n",
      "      %\"output\"<?,?>, %\"length\"<?,?> ⬅️ com.microsoft::ConcatTraining(%\"transpose\", %\"transpose\")\n",
      "      %\"_onx_concattraining0\"<?,?>, %\"_onx_concattraining1\"<?,?> ⬅️ com.microsoft::ConcatTrainingBad(%\"_onx_transpose0\", %\"_onx_transpose0\")\n",
      "    iteration=1\n",
      "    --matched-- #2\n",
      "      Cast(%\"cos\"<?,?>) ~ Cast(%\"_onx_cos0\"<?,?>) [140454351111376-140454351110224]\n",
      "      Cos(%\"output\"<?,?>) ~ Cos(%\"_onx_concattraining0\"<?,?>) [140454351111088-140455237148048]\n",
      "    len(stack)=0:[]\n",
      "[GenericPattern.match] done. backward failed.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rule = generic_pattern.make_pattern_rule(\n",
    "    rotary_match_pattern, rotary_apply_pattern, verbose=10\n",
    ")\n",
    "\n",
    "rule.apply_to_model(ir_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "日志显示了每次算法拒绝模式的情况。\n",
    "\n",
    "可能的信息为：\n",
    "\n",
    "```\n",
    "::\n",
    "\n",
    "    [OnnxGenericPattern.match] NONE - line: 673:onnxscript.rewriter.generic_pattern, op_type=Cast\n",
    "        --hint--: BACKWARD: different node types\n",
    "          --pattern\n",
    "          ConcatTraining(transpose, transpose) -> (output, length)\n",
    "          -- model\n",
    "          ConcatTrainingBad(_onx_transpose0, _onx_transpose0) -> (_onx_concattraining0, _onx_concattraining1)\n",
    "        iteration=1\n",
    "        --marked-- #2\n",
    "          Cast(_onx_cos0) ~ Cast(cos) [140186194226496-140186194222320]\n",
    "          Cos(_onx_concattraining0) ~ Cos(output) [140186194230816-140186194223472]\n",
    "        len(stacked)=0:[]\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在文件 `generic_pattern.py` 的第673行，匹配被拒绝了。它表示在向后方向比较两个节点时，节点类型不匹配。它还表示实际上有两个节点是匹配的。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "xin",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
